{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# torch.Tensor.register_hook(hook)  \n",
    "功能：为Tensor注册一个反向传播hook函数（register_hook是Tensor类的方法），每当计算该tensor的梯度时，hook会被自动调用。  \n",
    "形式： hook(grad) -> Tensor or None ，其中grad就是这个tensor的梯度。  \n",
    "返回值：a handle that can be used to remove the added hook by calling handle.remove() "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 例1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gradient: tensor([5.]) tensor([2.]) None None None\n",
      "a_grad[0]:  tensor([2.])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Collects every gradient that autograd passes to the hook registered on `a`.\n",
    "a_grad = []\n",
    "\n",
    "\n",
    "def grad_hook(grad):\n",
    "    \"\"\"Record the incoming gradient; nothing is returned, so the gradient is left as is.\"\"\"\n",
    "    a_grad.append(grad)\n",
    "\n",
    "\n",
    "# Leaf tensors of the graph.\n",
    "w = torch.tensor([1.], requires_grad=True)\n",
    "x = torch.tensor([2.], requires_grad=True)\n",
    "\n",
    "# Intermediate results of y = (w + x) * (w + 1).\n",
    "a = w + x  # retain_grad()\n",
    "b = w + 1\n",
    "y = a * b\n",
    "\n",
    "handle = a.register_hook(grad_hook)\n",
    "y.backward()\n",
    "\n",
    "# Inspect the gradients\n",
    "print(\"gradient:\", w.grad, x.grad, a.grad, b.grad, y.grad)\n",
    "print(\"a_grad[0]: \", a_grad[0])\n",
    "\n",
    "# Unregister the hook\n",
    "handle.remove()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 例2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "w.grad:  tensor([10.])\n"
     ]
    }
   ],
   "source": [
    "w = torch.tensor([1.], requires_grad=True)\n",
    "x = torch.tensor([2.], requires_grad=True)\n",
    "a = torch.add(w, x)\n",
    "b = torch.add(w, 1)\n",
    "y = torch.mul(a, b)\n",
    "\n",
    "\n",
    "def grad_hook(grad):\n",
    "    \"\"\"Double the gradient of the hooked tensor.\n",
    "\n",
    "    A hook should not modify its argument in place; the tensor returned\n",
    "    here is used by autograd in place of `grad`.\n",
    "    \"\"\"\n",
    "    return grad * 2\n",
    "\n",
    "\n",
    "handle = w.register_hook(grad_hook)\n",
    "y.backward()\n",
    "\n",
    "# Inspect the gradient: dy/dw = a + b = 5, which the hook doubles to 10\n",
    "print(\"w.grad: \", w.grad)\n",
    "handle.remove()"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# torch.nn.Module.register_forward_hook\n",
    "功能：Module前向传播中的hook,module在前向传播后，自动调用hook函数。  \n",
    "形式：hook(module, input, output) -> None 或修改后的output。hook在forward()执行之后才被调用，因此原地修改input不会影响本次前向计算；若hook返回非None的值，该值会替换output。\n",
    "其中，module是当前网络层，input是网络层的输入数据, output是网络层的输出数据\n",
    "\n",
    "应用场景：如用于提取特征图  \n",
    "举例：假设网络由卷积层conv1和池化层pool1构成，输入一张4\\*4的图片，现采用forward_hook获取module——conv1之后的feature maps，示意图如下：\n",
    "![](hook-demo.png)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Minimal network: a 3x3 convolution (1 -> 2 channels) followed by 2x2 max pooling.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 2, 3)\n",
    "        self.pool1 = nn.MaxPool2d(2, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply conv1 and then pool1 to x and return the result.\"\"\"\n",
    "        x = self.conv1(x)\n",
    "        x = self.pool1(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "def forward_hook(module, data_input, data_output):\n",
    "    \"\"\"Forward hook that records a layer's input and output.\n",
    "\n",
    "    Appends to the module-level lists `fmap_block` and `input_block`, which\n",
    "    must exist before the hooked module is called.\n",
    "    \"\"\"\n",
    "    fmap_block.append(data_output)\n",
    "    input_block.append(data_input)\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Initialise the network\n",
    "    net = Net()\n",
    "    # The parameters are leaf tensors that require grad, so in-place writes to\n",
    "    # them (or to views such as weight[0]) raise a RuntimeError unless autograd\n",
    "    # is disabled.\n",
    "    with torch.no_grad():\n",
    "        net.conv1.weight[0].fill_(1)\n",
    "        net.conv1.weight[1].fill_(2)\n",
    "        net.conv1.bias.zero_()\n",
    "\n",
    "    # Register the hook\n",
    "    fmap_block = list()\n",
    "    input_block = list()\n",
    "    handle = net.conv1.register_forward_hook(forward_hook)\n",
    "\n",
    "    # inference\n",
    "    fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W\n",
    "    output = net(fake_img)\n",
    "\n",
    "    # The hook has done its job; remove it so later forward passes do not keep appending.\n",
    "    handle.remove()\n",
    "\n",
    "    # Inspect the results\n",
    "    print(\"output shape: {}\\noutput value: {}\\n\".format(output.shape, output))\n",
    "    print(\"feature maps shape: {}\\noutput value: {}\\n\".format(fmap_block[0].shape, fmap_block[0]))\n",
    "    print(\"input shape: {}\\ninput value: {}\".format(input_block[0][0].shape, input_block[0]))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pt_1.11",
   "language": "python",
   "name": "pt_1.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}